/** * Copyright (C) 2010-14 diirt developers. See COPYRIGHT.TXT * All rights reserved. Use is subject to license terms. See LICENSE.TXT */ package org.diirt.pods.web; import org.diirt.pods.web.common.MessageValueEvent; import org.diirt.pods.web.common.MessageConnectionEvent; import org.diirt.pods.web.common.Message; import org.diirt.pods.web.common.MessageWriteCompletedEvent; import org.diirt.pods.web.common.MessageWrite; import org.diirt.pods.web.common.MessageSubscribe; import org.diirt.pods.web.common.MessageDecoder; import org.diirt.pods.web.common.MessageErrorEvent; import org.diirt.pods.web.common.MessageUnsubscribe; import org.diirt.pods.web.common.MessageEncoder; import org.diirt.pods.web.common.MessageResume; import org.diirt.pods.web.common.MessagePause; import java.io.InputStream; import java.time.Duration; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Level; import java.util.logging.Logger; import javax.servlet.http.HttpSession; import javax.websocket.CloseReason; import javax.websocket.EndpointConfig; import javax.websocket.OnClose; import javax.websocket.OnError; import javax.websocket.OnMessage; import javax.websocket.OnOpen; import javax.websocket.Session; import javax.websocket.server.ServerEndpoint; import org.diirt.pods.common.ChannelTranslation; import org.diirt.pods.common.ChannelTranslator; import org.diirt.util.config.Configuration; import org.diirt.util.time.TimeDuration; import org.diirt.datasource.PVManager; import org.diirt.datasource.PVReader; import org.diirt.datasource.PVReaderEvent; import org.diirt.datasource.PVReaderListener; import org.diirt.datasource.PVWriter; import org.diirt.datasource.PVWriterEvent; import org.diirt.datasource.PVWriterListener; import static org.diirt.datasource.formula.ExpressionLanguage.*; import org.diirt.datasource.formula.FormulaAst; import org.diirt.pods.common.ChannelRequest; import org.diirt.pods.web.common.MessageDecodeException; /** * * @author carcassi */ @ServerEndpoint(value = "/socket", decoders = {MessageDecoder.class}, encoders = {MessageEncoder.class}, configurator = WSEndpointConfigurator.class) public class WSEndpoint { // TODO: understand lifecycle of whole web application and put // configuration there, including closing datasources. static { ChannelTranslator temp = null; try (InputStream input = Configuration.getFileAsStream("pods/web/mappings.xml", new WSEndpoint(), "mappings.default.xml")) { temp = ChannelTranslator.loadTranslator(input); } catch (Exception ex) { Logger.getLogger(WSEndpoint.class.getName()).log(Level.SEVERE, "Couldn't load DIIRT_HOME/pods/web/mappings", ex); } channelTranslator = temp; } private static Logger log = Logger.getLogger(WSEndpoint.class.getName()); private static final ChannelTranslator channelTranslator; // XXX: need to understand how state can actually be used private final Map<Integer, PVReader<?>> channels = new ConcurrentHashMap<>(); private int defaultMaxRate; private String currentUser; private String remoteAddress; @OnMessage public void onMessage(Session session, Message message) { switch (message.getMessage()) { case SUBSCRIBE: onSubscribe(session, (MessageSubscribe) message); return; case UNSUBSCRIBE: onUnsubscribe(session, (MessageUnsubscribe) message); return; case PAUSE: onPause(session, (MessagePause) message); return; case RESUME: onResume(session, (MessageResume) message); return; case WRITE: onWrite(session, (MessageWrite) message); return; default: sendError(session, message.getId(), "Message '" + message.getMessage() + "' not supported on this server"); } } private void onSubscribe(final Session session, final MessageSubscribe message) { if (channels.get(message.getId()) != null) { sendError(session, message.getId(), "Subscription with id '" + message.getId() + "' already exists"); return; } // TODO: add maxRate check during parsing int maxRate = defaultMaxRate; if (message.getMaxRate() >= 20) { maxRate = message.getMaxRate(); } // First create the AST as seen by the client: authorization // step is based on the namespace as seen by the client FormulaAst clientAst; try { clientAst = FormulaAst.formula(message.getChannel()); } catch(RuntimeException ex) { sendError(session, message.getId(), ex.getMessage()); return; } List<String> clientChannels = clientAst.listChannelNames(); Map<String, FormulaAst> substitutions = new HashMap<>(); boolean readOnly = message.isReadOnly(); for (String clientChannel : clientChannels) { ChannelTranslation translation = channelTranslator.translate(new ChannelRequest(clientChannel, currentUser, null, null, remoteAddress)); // No channel map, return an error if (translation == null) { sendError(session, message.getId(), "Channel " + clientChannel + " does not exist"); return; } // No access to the channel, return an error if (translation.getPermission() == ChannelTranslation.Permission.NONE) { sendError(session, message.getId(), "No access to channel " + clientChannel); return; } if (!message.isReadOnly() && translation.getPermission() == ChannelTranslation.Permission.READ_ONLY) { sendError(session, message.getId(), "No write access to channel " + clientChannel); readOnly = true; } try { substitutions.put(clientChannel, FormulaAst.formula(translation.getFormula())); } catch (RuntimeException ex) { sendError(session, message.getId(), ex.getMessage()); return; } } connect(readOnly, clientAst.substituteChannels(substitutions), session, message, maxRate); } private void connect(boolean readOnly, FormulaAst ast, final Session session, final MessageSubscribe message, int maxRate) { PVReader<?> reader; if (readOnly) { reader = PVManager.read(ast.toExpression()) .readListener(new ReadOnlyListener(session, message)) .timeout(TimeDuration.ofSeconds(1.0), "Still connecting...") .maxRate(Duration.ofMillis(maxRate)); } else { ReadWriteListener readWriteListener = new ReadWriteListener(session, message); reader = PVManager.readAndWrite(formula(ast)) .readListener(readWriteListener) .writeListener(readWriteListener) .timeout(TimeDuration.ofSeconds(1.0), "Still connecting...") .asynchWriteAndMaxReadRate(Duration.ofMillis(maxRate)); } channels.put(message.getId(), reader); } private void onUnsubscribe(Session session, MessageUnsubscribe message) { PVReader<?> channel = channels.remove(message.getId()); if (channel != null) { channel.close(); } else { sendError(session, message.getId(), "Subscription with id '" + message.getId() + "' does not exist"); } } private void onPause(Session session, MessagePause message) { PVReader<?> channel = channels.get(message.getId()); if (channel != null) { channel.setPaused(true); } } private void onResume(Session session, MessageResume message) { PVReader<?> channel = channels.get(message.getId()); if (channel != null) { channel.setPaused(false); } } private void onWrite(Session session, MessageWrite message) { PVReader<?> channel = channels.get(message.getId()); if (channel instanceof PVWriter) { @SuppressWarnings("unchecked") PVWriter<Object> channelWriter = (PVWriter<Object>) channel; channelWriter.write(message.getValue()); } else { if (channel == null) { sendError(session, message.getId(), "Channel id '" + message.getId() + "' is not open"); } else { sendError(session, message.getId(), "Channel id '" + message.getId() + "' is read-only"); } } } @OnOpen public void onOpen(Session session, EndpointConfig config) { // Read the maxRate parameter String maxRate = session.getPathParameters().get("maxRate"); if (maxRate != null) { try { defaultMaxRate = Integer.parseInt(maxRate); if (defaultMaxRate < 20) { sendError(session, -1, "maxRate must be greater than 20"); defaultMaxRate = 1000; } } catch (NumberFormatException ex) { sendError(session, -1, "maxRate must be an integer"); } } else { defaultMaxRate = 1000; } // Retrive user and remote host for security purposes HttpSession httpSession = (HttpSession) config.getUserProperties().get("session"); remoteAddress = (String) httpSession.getAttribute("remoteHost"); if (session.getUserPrincipal() != null) { currentUser = session.getUserPrincipal().getName(); } else { currentUser = null; } } @OnClose public void onClose(Session session, CloseReason reason) { for (Map.Entry<Integer, PVReader<?>> entry : channels.entrySet()) { PVReader<?> channel = entry.getValue(); channel.close(); } closed = true; } private volatile boolean closed = false; @OnError public void onError(Session session, Throwable cause) { if (cause instanceof MessageDecodeException) { MessageDecodeException de = (MessageDecodeException) cause; sendError(session, de.getId(), cause.getMessage()); } else { log.log(Level.WARNING, "Unhandled exception", cause); } } public void sendError(Session session, int id, String message) { session.getAsyncRemote().sendObject(new MessageErrorEvent(id, message)); } private class ReadOnlyListener implements PVReaderListener<Object> { private final Session session; private final MessageSubscribe message; public ReadOnlyListener(Session session, MessageSubscribe message) { this.session = session; this.message = message; } @Override public void pvChanged(PVReaderEvent<Object> event) { try { if (closed) { log.log(Level.SEVERE, "Getting event after channel was closed for " + event.getPvReader().getName()); event.getPvReader().close(); return; } if (event.isConnectionChanged()) { session.getAsyncRemote().sendObject(new MessageConnectionEvent(message.getId(), event.getPvReader().isConnected(), false)); } if (event.isValueChanged()) { session.getAsyncRemote().sendObject(new MessageValueEvent(message.getId(), event.getPvReader().getValue())); } if (event.isExceptionChanged()) { session.getAsyncRemote().sendObject(new MessageErrorEvent(message.getId(), event.getPvReader().lastException().getMessage())); } } catch (RuntimeException ex) { log.log(Level.SEVERE, "Error while preparing event for " + event.getPvReader().getName(), ex); } } } private static boolean readConnected(Object channel) { @SuppressWarnings("unchecked") PVReader<Object> reader = (PVReader<Object>) channel; return reader.isConnected(); } private class ReadWriteListener implements PVReaderListener<Object>, PVWriterListener<Object> { private final Session session; private final MessageSubscribe message; public ReadWriteListener(Session session, MessageSubscribe message) { this.session = session; this.message = message; } @Override public void pvChanged(PVReaderEvent<Object> event) { try { if (closed) { log.log(Level.SEVERE, "Getting event after channel was closed for " + event.getPvReader().getName()); event.getPvReader().close(); return; } if (event.isValueChanged()) { session.getAsyncRemote().sendObject(new MessageValueEvent(message.getId(), event.getPvReader().getValue())); } if (event.isExceptionChanged()) { session.getAsyncRemote().sendObject(new MessageErrorEvent(message.getId(), event.getPvReader().lastException().getMessage())); } } catch (RuntimeException ex) { log.log(Level.SEVERE, "Error while preparing event for " + event.getPvReader().getName(), ex); } } @Override public void pvChanged(PVWriterEvent<Object> event) { try { if (closed) { log.log(Level.SEVERE, "Getting event after channel was closed for " + event.getPvWriter()); event.getPvWriter().close(); return; } if (event.isConnectionChanged()) { session.getAsyncRemote().sendObject(new MessageConnectionEvent(message.getId(), readConnected(event.getPvWriter()), event.getPvWriter().isWriteConnected())); } if (event.isWriteSucceeded()) { session.getAsyncRemote().sendObject(new MessageWriteCompletedEvent(message.getId())); } if (event.isWriteFailed()) { session.getAsyncRemote().sendObject(new MessageWriteCompletedEvent(message.getId(), event.getPvWriter().lastWriteException().getMessage())); } if (event.isExceptionChanged()) { session.getAsyncRemote().sendObject(new MessageErrorEvent(message.getId(), event.getPvWriter().lastWriteException().getMessage())); } } catch (RuntimeException ex) { log.log(Level.SEVERE, "Error while preparing event for " + event.getPvWriter(), ex); } } } }